{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "395e5929",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Global configuration shared by the cells below.\n",
    "# The Hub access token is read from disk instead of being hard-coded; the\n",
    "# context manager closes the file handle as soon as the token has been read.\n",
    "with open('/root/hub_token.txt') as token_file:\n",
    "    hub_token = token_file.read().strip()\n",
    "\n",
    "# Hub repository that receives the mirrored dataset and the trained model\n",
    "repo_id = 'lansinuote/diffusion.7.control_net'\n",
    "# When True the notebook uploads the dataset, trains, and pushes the model\n",
    "push_to_hub = True\n",
    "# Base checkpoint providing the tokenizer, text encoder, VAE, UNet and scheduler\n",
    "checkpoint = 'runwayml/stable-diffusion-v1-5'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "69854652",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration lansinuote--diffusion.7.control_net-93b816e0abd9e137\n",
      "Found cached dataset parquet (/root/.cache/huggingface/datasets/lansinuote___parquet/lansinuote--diffusion.7.control_net-93b816e0abd9e137/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Dataset({\n",
      "    features: ['image', 'conditioning_image', 'text'],\n",
      "    num_rows: 50000\n",
      "}) {'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x7F9695DF5CD0>, 'conditioning_image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=512x512 at 0x7F9695DF5DF0>, 'text': 'pale golden rod circle with old lace background'}\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# One-off mirroring step: copy the public fill50k dataset into our own repo.\n",
    "# The images are deliberately not pre-encoded with `map`: the encoded data\n",
    "# would be too large to handle.\n",
    "if push_to_hub:\n",
    "    load_dataset('fusing/fill50k').push_to_hub(repo_id=repo_id,\n",
    "                                               token=hub_token)\n",
    "\n",
    "# Work from the mirrored copy of the dataset\n",
    "dataset = load_dataset(path=repo_id, split='train')\n",
    "\n",
    "print(dataset, dataset[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "43037e7c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(50000,\n",
       " {'input_ids': tensor([[49406, 16033,  7117,   593, 16776,  5994, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407, 49407,\n",
       "           49407, 49407, 49407, 49407, 49407, 49407, 49407]]),\n",
       "  'pixel_values': tensor([[[[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            ...,\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],\n",
       "  \n",
       "           [[1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            ...,\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000],\n",
       "            [1.0000, 1.0000, 1.0000,  ..., 1.0000, 1.0000, 1.0000]],\n",
       "  \n",
       "           [[0.8824, 0.8824, 0.8824,  ..., 0.8824, 0.8824, 0.8824],\n",
       "            [0.8824, 0.8824, 0.8824,  ..., 0.8824, 0.8824, 0.8824],\n",
       "            [0.8824, 0.8824, 0.8824,  ..., 0.8824, 0.8824, 0.8824],\n",
       "            ...,\n",
       "            [0.8824, 0.8824, 0.8824,  ..., 0.8824, 0.8824, 0.8824],\n",
       "            [0.8824, 0.8824, 0.8824,  ..., 0.8824, 0.8824, 0.8824],\n",
       "            [0.8824, 0.8824, 0.8824,  ..., 0.8824, 0.8824, 0.8824]]]]),\n",
       "  'conditioning_pixel_values': tensor([[[[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            ...,\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "  \n",
       "           [[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            ...,\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.]],\n",
       "  \n",
       "           [[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            ...,\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "            [0., 0., 0.,  ..., 0., 0., 0.]]]])})"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import CLIPTokenizer\n",
    "import torchvision\n",
    "\n",
    "# Tokenizer stored in the 'tokenizer' subfolder of the base checkpoint\n",
    "tokenizer = CLIPTokenizer.from_pretrained(checkpoint, subfolder='tokenizer')\n",
    "\n",
    "# Shared image pipeline: bilinear resize, 512 center crop, PIL image -> tensor\n",
    "compose = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.Resize(\n",
    "        size=512,\n",
    "        interpolation=torchvision.transforms.InterpolationMode.BILINEAR,\n",
    "    ),\n",
    "    torchvision.transforms.CenterCrop(size=512),\n",
    "    torchvision.transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "# Extra normalisation, (x - 0.5) / 0.5, applied to the target images only\n",
    "norm = torchvision.transforms.Normalize(mean=[0.5], std=[0.5])\n",
    "\n",
    "\n",
    "# Batch assembly used by the DataLoader\n",
    "def collate_fn(data):\n",
    "    \"\"\"Turn a list of dataset rows into a dict of batched tensors.\n",
    "\n",
    "    Each row must provide the keys 'text', 'image' and 'conditioning_image'.\n",
    "\n",
    "    Returns a dict with the keys 'input_ids', 'pixel_values' and\n",
    "    'conditioning_pixel_values'.\n",
    "    \"\"\"\n",
    "    # Tokenize the captions, padded / truncated to a fixed length\n",
    "    # 77 = tokenizer.model_max_length\n",
    "    captions = [row['text'] for row in data]\n",
    "    encoded = tokenizer.batch_encode_plus(captions,\n",
    "                                          max_length=77,\n",
    "                                          padding='max_length',\n",
    "                                          truncation=True,\n",
    "                                          return_tensors=None)\n",
    "\n",
    "    # Target images: shared transform followed by normalisation\n",
    "    targets = torch.stack([norm(compose(row['image'])) for row in data])\n",
    "\n",
    "    # Conditioning images: shared transform only, no normalisation\n",
    "    hints = torch.stack(\n",
    "        [compose(row['conditioning_image']) for row in data])\n",
    "\n",
    "    return {\n",
    "        'input_ids': torch.LongTensor(encoded.input_ids),\n",
    "        'pixel_values': targets,\n",
    "        'conditioning_pixel_values': hints,\n",
    "    }\n",
    "\n",
    "\n",
    "# Shuffled loader that yields one collated sample per batch\n",
    "loader = torch.utils.data.DataLoader(\n",
    "    dataset,\n",
    "    batch_size=1,\n",
    "    shuffle=True,\n",
    "    collate_fn=collate_fn,\n",
    ")\n",
    "\n",
    "# Show the number of batches together with one example batch\n",
    "len(loader), next(iter(loader))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "217c0a9c",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "encoder 12306.048\n",
      "vae 8365.3863\n",
      "unet 85952.0964\n",
      "controlnet 36127.912\n"
     ]
    }
   ],
   "source": [
    "from diffusers import AutoencoderKL, UNet2DConditionModel\n",
    "from transformers import CLIPTextModel\n",
    "from transformers import PretrainedConfig\n",
    "\n",
    "# Load the three pretrained models from the base checkpoint\n",
    "encoder = CLIPTextModel.from_pretrained(checkpoint, subfolder='text_encoder')\n",
    "vae = AutoencoderKL.from_pretrained(checkpoint, subfolder='vae')\n",
    "unet = UNet2DConditionModel.from_pretrained(checkpoint, subfolder='unet')\n",
    "\n",
    "# Build the ControlNet. ControlNet and load_params are not defined in this\n",
    "# notebook; presumably the companion notebook executed here provides them.\n",
    "%run controlnet_model.ipynb\n",
    "controlnet = ControlNet(PretrainedConfig())\n",
    "load_params(controlnet, unet)\n",
    "\n",
    "# The pretrained models stay frozen; the optimizer only sees the ControlNet\n",
    "encoder.requires_grad_(False)\n",
    "vae.requires_grad_(False)\n",
    "unet.requires_grad_(False)\n",
    "\n",
    "def print_model_size(name, model):\n",
    "    \"\"\"Print `name` and the model's parameter count in units of 10,000.\"\"\"\n",
    "    param_count = sum(p.numel() for p in model.parameters())\n",
    "    print(name, param_count / 10000)\n",
    "\n",
    "\n",
    "# Report the size of every model involved in training\n",
    "print_model_size(name='encoder', model=encoder)\n",
    "print_model_size(name='vae', model=vae)\n",
    "print_model_size(name='unet', model=unet)\n",
    "print_model_size(name='controlnet', model=controlnet)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a5f838ac",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DDPMScheduler {\n",
       "   \"_class_name\": \"DDPMScheduler\",\n",
       "   \"_diffusers_version\": \"0.15.1\",\n",
       "   \"beta_end\": 0.012,\n",
       "   \"beta_schedule\": \"scaled_linear\",\n",
       "   \"beta_start\": 0.00085,\n",
       "   \"clip_sample\": false,\n",
       "   \"clip_sample_range\": 1.0,\n",
       "   \"dynamic_thresholding_ratio\": 0.995,\n",
       "   \"num_train_timesteps\": 1000,\n",
       "   \"prediction_type\": \"epsilon\",\n",
       "   \"sample_max_value\": 1.0,\n",
       "   \"set_alpha_to_one\": false,\n",
       "   \"skip_prk_steps\": true,\n",
       "   \"steps_offset\": 1,\n",
       "   \"thresholding\": false,\n",
       "   \"trained_betas\": null,\n",
       "   \"variance_type\": \"fixed_small\"\n",
       " },\n",
       " AdamW (\n",
       " Parameter Group 0\n",
       "     amsgrad: False\n",
       "     betas: (0.9, 0.999)\n",
       "     capturable: False\n",
       "     eps: 1e-08\n",
       "     foreach: None\n",
       "     lr: 1e-05\n",
       "     maximize: False\n",
       "     weight_decay: 0.01\n",
       " ),\n",
       " MSELoss())"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import DDPMScheduler\n",
    "\n",
    "# Noise scheduler configured by the base checkpoint\n",
    "scheduler = DDPMScheduler.from_pretrained(checkpoint, subfolder='scheduler')\n",
    "\n",
    "# Only the ControlNet parameters are optimised\n",
    "optimizer = torch.optim.AdamW(\n",
    "    controlnet.parameters(),\n",
    "    lr=1e-5,\n",
    "    weight_decay=0.01,\n",
    "    betas=(0.9, 0.999),\n",
    "    eps=1e-8,\n",
    ")\n",
    "\n",
    "# Mean squared error, used between the predicted and the sampled noise\n",
    "criterion = torch.nn.MSELoss()\n",
    "\n",
    "scheduler, optimizer, criterion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "a3f32e9f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor(0.2469, grad_fn=<MseLossBackward0>)"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_loss(data):\n",
    "    \"\"\"Compute the noise-prediction MSE loss for one batch.\n",
    "\n",
    "    `data` is a dict holding the tensors 'input_ids', 'pixel_values' and\n",
    "    'conditioning_pixel_values'. The noise timesteps are placed on the\n",
    "    device of 'input_ids'.\n",
    "\n",
    "    Returns the value of `criterion` between the unet output and the noise\n",
    "    that was added to the latents.\n",
    "    \"\"\"\n",
    "    device = data['input_ids'].device\n",
    "\n",
    "    # Text encoding\n",
    "    # [b, 77] -> [b, 77, 768]\n",
    "    out_encoder = encoder(data['input_ids'])[0]\n",
    "\n",
    "    # Encode the target image into a latent feature map\n",
    "    # [b, 3, 512, 512] -> [b, 4, 64, 64]\n",
    "    out_vae = vae.encode(data['pixel_values']).latent_dist.sample()\n",
    "    # 0.18215 = vae.config.scaling_factor\n",
    "    out_vae = out_vae * 0.18215\n",
    "\n",
    "    # Random noise: this is the target the unet has to predict\n",
    "    noise = torch.randn_like(out_vae)\n",
    "\n",
    "    # Add noise to the latents at a random timestep. One timestep is drawn per\n",
    "    # sample, so a batch larger than 1 does not share a single timestep.\n",
    "    # NOTE(review): confirm the custom ControlNet accepts a timestep vector\n",
    "    # of length b before raising the loader's batch_size above 1.\n",
    "    # 1000 = scheduler.num_train_timesteps\n",
    "    batch_size = out_vae.shape[0]\n",
    "    noise_step = torch.randint(0, 1000, (batch_size, )).long()\n",
    "    noise_step = noise_step.to(device)\n",
    "    out_vae_noise = scheduler.add_noise(out_vae, noise, noise_step)\n",
    "\n",
    "    # The controlnet produces extra residuals for the unet's down and mid parts.\n",
    "    # The unet still computes its own down and mid features; the controlnet\n",
    "    # results are added on top, injecting the conditioning information.\n",
    "    out_control_down, out_control_mid = controlnet(\n",
    "        out_vae_noise,\n",
    "        noise_step,\n",
    "        out_encoder,\n",
    "        condition=data['conditioning_pixel_values'])\n",
    "\n",
    "    # Predict the noise contained in the latents, guided by the text\n",
    "    out_unet = unet(out_vae_noise,\n",
    "                    noise_step,\n",
    "                    encoder_hidden_states=out_encoder,\n",
    "                    down_block_additional_residuals=out_control_down,\n",
    "                    mid_block_additional_residual=out_control_mid).sample\n",
    "\n",
    "    # MSE loss between prediction and noise\n",
    "    # [b, 4, 64, 64],[b, 4, 64, 64]\n",
    "    return criterion(out_unet, noise)\n",
    "\n",
    "\n",
    "# Smoke test: run the loss once on a dummy batch of random inputs\n",
    "get_loss(\n",
    "    dict(\n",
    "        input_ids=torch.ones(1, 77).long(),\n",
    "        pixel_values=torch.randn(1, 3, 512, 512),\n",
    "        conditioning_pixel_values=torch.randn(1, 3, 512, 512),\n",
    "    ))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "84b88599",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.0026870747096836567\n",
      "2000 3.461793488706462\n",
      "4000 3.049983833108854\n",
      "6000 2.88362369704555\n",
      "8000 3.005527748755412\n",
      "10000 2.7510610005374474\n",
      "12000 2.7980021113398834\n",
      "14000 2.4720400111327763\n",
      "16000 2.62037534315823\n",
      "18000 2.375480920953123\n",
      "20000 2.5328880913584726\n",
      "22000 2.540079788388539\n",
      "24000 2.484456200958448\n",
      "26000 2.3819653876125813\n",
      "28000 2.3381337551472825\n",
      "30000 2.383196011167456\n",
      "32000 2.33385793850357\n",
      "34000 2.2283505542582134\n",
      "36000 2.0874762790735986\n",
      "38000 2.211367720625276\n",
      "40000 2.3298714868487878\n",
      "42000 2.2381660530427325\n",
      "44000 2.1768586739563034\n",
      "46000 2.202571557407282\n",
      "48000 2.176738911841312\n"
     ]
    }
   ],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "\n",
    "def train():\n",
    "    \"\"\"Train the ControlNet for one pass over `loader` and save it locally.\n",
    "\n",
    "    Gradients are accumulated over 4 batches per optimizer step and clipped\n",
    "    to a norm of 1.0. The accumulated loss is printed every 2000 batches.\n",
    "    When the pass ends the ControlNet is moved back to the CPU and saved to\n",
    "    './save/controlnet.model'.\n",
    "    \"\"\"\n",
    "    import os\n",
    "\n",
    "    accumulation_steps = 4\n",
    "\n",
    "    # Create the output directory up front so that a missing folder cannot\n",
    "    # make torch.save fail after the whole training run has finished.\n",
    "    os.makedirs('./save', exist_ok=True)\n",
    "\n",
    "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "    unet.to(device)\n",
    "    encoder.to(device)\n",
    "    vae.to(device)\n",
    "    controlnet.to(device)\n",
    "    controlnet.train()\n",
    "\n",
    "    def apply_update():\n",
    "        # Clip the accumulated gradients, take one step, reset the gradients\n",
    "        torch.nn.utils.clip_grad_norm_(controlnet.parameters(), 1.0)\n",
    "        optimizer.step()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "    loss_sum = 0\n",
    "    pending = 0  # batches accumulated since the last optimizer step\n",
    "    for i, data in enumerate(loader):\n",
    "        for k in data.keys():\n",
    "            data[k] = data[k].to(device)\n",
    "\n",
    "        loss = get_loss(data) / accumulation_steps\n",
    "        loss.backward()\n",
    "        loss_sum += loss.item()\n",
    "        pending += 1\n",
    "\n",
    "        # Step only once a full group has been accumulated. The previous\n",
    "        # `i % 4 == 0` test fired at i == 0, after a single batch.\n",
    "        if pending == accumulation_steps:\n",
    "            apply_update()\n",
    "            pending = 0\n",
    "\n",
    "        if i % 2000 == 0:\n",
    "            print(i, loss_sum)\n",
    "            loss_sum = 0\n",
    "\n",
    "    # Apply gradients left over from an incomplete final group\n",
    "    if pending:\n",
    "        apply_update()\n",
    "\n",
    "    # Save locally\n",
    "    torch.save(controlnet.cpu(), './save/controlnet.model')\n",
    "\n",
    "\n",
    "# Training and the upload only run when publishing is enabled\n",
    "if push_to_hub:\n",
    "    train()\n",
    "    # Upload the trained ControlNet to the Hub repository\n",
    "    controlnet.push_to_hub(\n",
    "        repo_id=repo_id,\n",
    "        use_auth_token=hub_token,\n",
    "    )"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt39]",
   "language": "python",
   "name": "conda-env-pt39-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
